load("//bazel:tachyon.bzl", "if_aarch64", "if_has_avx512", "if_x86_64")
load("//bazel:tachyon_cc.bzl", "tachyon_avx512_defines", "tachyon_cc_library")
load("//tachyon/math/finite_fields/generator/prime_field_generator:build_defs.bzl", "generate_prime_fields")

package(default_visibility = ["//visibility:public"])

# Declares the `Mersenne31` prime field in namespace `tachyon::math` through
# the `generate_prime_fields` macro. The other targets in this package depend
# on it as `:mersenne31`.
generate_prime_fields(
    name = "mersenne31",
    class_name = "Mersenne31",
    # 2³¹ - 1
    # Hex: 0x7fffffff
    modulus = "2147483647",
    namespace = "tachyon::math",
    # C++ statement passed to the generator as a string. It references `v`,
    # `kModulusBits` and `kModulus`, which are not defined in this file.
    # NOTE(review): presumably this is the body of a 64-bit reduction routine
    # that folds the bits above `kModulusBits` back into the low bits — confirm
    # against the generator templates.
    reduce64 = "return ((((v >> kModulusBits) + v + 1) >> kModulusBits) + v) & kModulus;",
    # NOTE(review): presumably selects a plain (non-Montgomery) representation
    # for the generated field — confirm in build_defs.bzl.
    use_montgomery = False,
)

# Header-only target exposing `packed_mersenne31.h`. It has no `srcs`; the
# architecture-specific code comes from the dependencies listed below.
tachyon_cc_library(
    name = "packed_mersenne31",
    hdrs = ["packed_mersenne31.h"],
    # Preprocessor defines come from `tachyon_avx512_defines()`, loaded from
    # //bazel:tachyon_cc.bzl.
    defines = tachyon_avx512_defines(),
    deps = [
        # The NEON target is listed unconditionally; its own srcs/hdrs are
        # wrapped in `if_aarch64`.
        ":packed_mersenne31_neon",
        "//tachyon/build:build_config",
        "//tachyon/math/finite_fields:finite_field_traits",
        "//tachyon/math/matrix:prime_field_num_traits",
    ] + if_has_avx512(
        # NOTE(review): presumably the first list is used when AVX512 is
        # enabled and the second otherwise — confirm in //bazel:tachyon.bzl.
        [":packed_mersenne31_avx512"],
        [":packed_mersenne31_avx2"],
    ),
)

# AVX2 implementation target. `srcs`, `hdrs` and `copts` are all wrapped in
# `if_x86_64`, while `deps` is listed unconditionally.
# NOTE(review): presumably `if_x86_64` yields an empty list on other
# architectures — confirm in //bazel:tachyon.bzl.
tachyon_cc_library(
    name = "packed_mersenne31_avx2",
    srcs = if_x86_64(["packed_mersenne31_avx2.cc"]),
    hdrs = if_x86_64(["packed_mersenne31_avx2.h"]),
    # Compiles this target with `-mavx2`.
    copts = if_x86_64(["-mavx2"]),
    deps = [
        ":mersenne31",
        "//tachyon:export",
        "//tachyon/math/finite_fields:packed_prime_field32_avx2",
        "//tachyon/math/finite_fields:packed_prime_field_base",
    ],
)

# AVX512 implementation target. `srcs`, `hdrs` and `copts` are all wrapped in
# `if_x86_64`, while `deps` is listed unconditionally.
# NOTE(review): presumably `if_x86_64` yields an empty list on other
# architectures — confirm in //bazel:tachyon.bzl.
tachyon_cc_library(
    name = "packed_mersenne31_avx512",
    srcs = if_x86_64(["packed_mersenne31_avx512.cc"]),
    hdrs = if_x86_64(["packed_mersenne31_avx512.h"]),
    # Compiles this target with `-mavx512f` and `-mfma`.
    copts = if_x86_64([
        "-mavx512f",
        # NOTE(chokobole): See https://gitlab.com/libeigen/eigen/-/blob/0b51f76/Eigen/src/Core/util/ConfigureVectorization.h#L248.
        "-mfma",
    ]),
    deps = [
        ":mersenne31",
        "//tachyon:export",
        "//tachyon/math/finite_fields:packed_prime_field32_avx512",
        "//tachyon/math/finite_fields:packed_prime_field_base",
    ],
)

# NEON implementation target. `srcs` and `hdrs` are wrapped in `if_aarch64`,
# while `deps` is listed unconditionally. Unlike the AVX targets in this
# package, it sets no `copts`.
# NOTE(review): presumably `if_aarch64` yields an empty list on other
# architectures — confirm in //bazel:tachyon.bzl.
tachyon_cc_library(
    name = "packed_mersenne31_neon",
    srcs = if_aarch64(["packed_mersenne31_neon.cc"]),
    hdrs = if_aarch64(["packed_mersenne31_neon.h"]),
    deps = [
        ":mersenne31",
        "//tachyon:export",
        "//tachyon/math/finite_fields:packed_prime_field32_neon",
        "//tachyon/math/finite_fields:packed_prime_field_base",
    ],
)
